import cv2

import os
import sys
# Put /system/lib at the front of sys.path (once) so that the imports below
# can resolve packages installed there.
# NOTE(review): presumably this is where the Sophon SAIL Python bindings
# live on the target board -- confirm.
package_path = "/system/lib"
if package_path not in sys.path:
    sys.path.insert(0,package_path)

# import sophon.sail as sail
import numpy as np
import serial

def preprocess(frames):
    """Pack camera frames into one (8, 3, 32, 32) float32 batch.

    Every frame is resized to 32x32, converted from BGR to RGB, shifted
    and scaled with ``(pixel - 127.5) / 127.5`` (which maps 8-bit values
    onto [-1, 1]) and transposed from (H, W, C) to (C, H, W).

    Args:
        frames: Iterable of BGR images as returned by ``cv2.VideoCapture``.
            At most 8 frames fit into the batch; slots without a frame
            stay zero-filled.

    Returns:
        A ``numpy.ndarray`` of shape (8, 3, 32, 32) and dtype float32.
    """
    batch = np.zeros((8, 3, 32, 32), dtype=np.float32)
    scale = 1 / 127.5

    for slot, frame in enumerate(frames):
        resized = cv2.resize(
            src=frame, dsize=(32, 32), interpolation=cv2.INTER_LINEAR
        )
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)
        # Centre on zero, then scale to unit range.
        normalised = (rgb - 127.5) * scale
        # Channel-first layout: (H, W, C) -> (C, H, W).
        batch[slot] = normalised.transpose(2, 0, 1)

    return batch

def postprocess(outputs):
    """Return the index of the largest value in the model's first output.

    Args:
        outputs: Mapping of output name to array, as returned by the
            model's ``process`` call. Only the first value is inspected.

    Returns:
        The argmax over the flattened first output array.
    """
    first_output = tuple(outputs.values())[0]
    return first_output.argmax()

def calculate_bias(original_image):
    """Estimate the lateral bias of the road from a grayscale frame.

    The frame is binarised with an inverted adaptive threshold (pixels
    darker than their local neighbourhood become foreground), cleaned
    with a morphological close followed by an open, and the convex hull
    of all foreground pixels is taken. Hull vertices are split at their
    median x into a left and a right group; a straight line x = a*y + b
    is fitted to each group, and the midpoint between the two lines at
    half the image height is compared with the image centre column.

    Args:
        original_image: Single-channel 8-bit image (the callers in this
            file pass the result of ``cv2.COLOR_BGR2GRAY``).

    Returns:
        The bias clipped to [-100, 100], or ``None`` when no estimate is
        possible: either no foreground pixel survived the clean-up, or
        one side of the hull has fewer than two vertices so no line can
        be fitted to it.
    """
    # Inverted adaptive threshold: 11x11 Gaussian neighbourhood, offset 2.
    adaptive_thresh = cv2.adaptiveThreshold(
        original_image,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )

    # Close small holes first, then remove small specks.
    kernel = np.ones((5, 5), np.uint8)
    cleaned_image = cv2.morphologyEx(adaptive_thresh, cv2.MORPH_CLOSE, kernel)
    cleaned_image = cv2.morphologyEx(cleaned_image, cv2.MORPH_OPEN, kernel)

    height, width = cleaned_image.shape[:2]
    mid_x = width // 2

    # Coordinates of every foreground (non-zero) pixel in the mask.
    y_coords, x_coords = np.nonzero(cleaned_image)
    if x_coords.size == 0:
        return None

    # cv2.convexHull expects 32-bit integer or float points; np.nonzero
    # yields the platform index type, so cast explicitly.
    points = np.column_stack((x_coords, y_coords)).astype(np.int32)
    hull_points = cv2.convexHull(points).reshape(-1, 2)
    hull_x = hull_points[:, 0]
    hull_y = hull_points[:, 1]

    # Split hull vertices into left/right groups around the median x.
    median_x = np.median(hull_x)
    left_mask = hull_x < median_x
    right_mask = ~left_mask

    # A line needs two points. Without this guard np.polyfit raises on an
    # empty group (e.g. a purely vertical blob, where every hull x equals
    # the median) and returns a meaningless fit for a single point.
    if np.count_nonzero(left_mask) < 2 or np.count_nonzero(right_mask) < 2:
        return None

    # Fit x as a linear function of y so near-vertical edges stay well
    # conditioned.
    left_fit = np.polyfit(hull_y[left_mask], hull_x[left_mask], 1)
    right_fit = np.polyfit(hull_y[right_mask], hull_x[right_mask], 1)

    # Midpoint between both edge lines at half the image height.
    mid_height = height // 2
    left_x_at_mid = np.polyval(left_fit, mid_height)
    right_x_at_mid = np.polyval(right_fit, mid_height)
    midline_at_center = (left_x_at_mid + right_x_at_mid) / 2

    # NOTE(review): with this formula a perfectly centred midline yields
    # 100 (not 0), and the result only goes negative once the midline is
    # more than width/2 to the right of centre. That looks unintended for
    # a "normalised -100..100" bias, but it is kept unchanged because the
    # serial consumer may be tuned to it -- confirm before changing.
    bias = 100 * (1 - 2 * (midline_at_center - mid_x) / width)
    return np.clip(bias, -100, 100)

def infer(frame, model, graph_name, input_name):
    """Classify an 8-frame clip starting at ``frame`` and estimate its bias.

    ``frame`` is used as the first frame of the clip; seven more frames
    are read from camera 0 to complete the batch. The bias is computed
    from ``frame`` only.

    Args:
        frame: BGR image that starts the clip.
        model: Engine object exposing ``process(graph_name, input_data)``.
        graph_name: Name of the graph to run on ``model``.
        input_name: Name of the graph's input tensor.

    Returns:
        ``(pred_idx, bias)`` on success, where ``bias`` may be ``None``
        if no bias could be estimated. ``(None, None)`` when fewer than
        eight frames could be collected.
    """
    batch_size = 8
    frames_buffer = [frame]

    cap = cv2.VideoCapture(0)
    try:
        for _ in range(batch_size - 1):
            ret, new_frame = cap.read()
            if not ret:
                # One failed read already makes a full batch impossible,
                # so there is no point in reading further.
                break
            frames_buffer.append(new_frame)
    finally:
        # Always hand the camera back, even if a read raises.
        cap.release()

    if len(frames_buffer) != batch_size:
        return None, None

    input_data = {input_name: preprocess(frames_buffer)}
    outputs = model.process(graph_name, input_data)
    pred_idx = postprocess(outputs)

    # Bias is estimated on the grayscale version of the original frame.
    gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    bias = calculate_bias(gray_frame)

    return pred_idx, bias

def _select_ts_class(pred_idx, bias):
    """Map the predicted class index and the bias to the class byte to send."""
    if pred_idx == 2:
        return 0
    if pred_idx == 1:
        if bias < 0:
            return 2
        if bias >= 0:
            return 3
    # pred_idx == 0 and every unrecognised combination map to class 1.
    return 1


def infer_from_camera(model_path):
    """Run the live pipeline: camera -> model -> serial command.

    Frames are read from camera 0 and shown in a window named 'Camera'.
    Whenever eight frames have accumulated they are classified, the bias
    of the newest frame is estimated, and (if a bias is available) the
    4-byte command ``b'C' + class + b'P' + position`` is written to the
    serial port. The newest frame is then kept as the first frame of the
    next batch. The loop ends when a frame cannot be read or 'q' is
    pressed.

    Args:
        model_path: Path to the compiled model file passed to
            ``sail.Engine``.

    Returns:
        None.
    """
    # The module-level import of sail is commented out, so import it here;
    # otherwise the sail.Engine call below raises NameError.
    import sophon.sail as sail

    # NOTE(review): the second argument (0) is presumably the device id --
    # confirm against the SAIL documentation.
    model = sail.Engine(model_path, 0, sail.IOMode.SYSIO)
    graph_name = model.get_graph_names()[0]
    input_name = model.get_input_names(graph_name)[0]

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open camera.")
        return

    port = '/dev/serial/by-id/usb-1a86_USB_Serial-if00-port0'
    baudrate = 115200
    # Opened lazily on the first command and reused afterwards; opening a
    # new handle for every batch leaked the previous one.
    ser = None

    frames_buffer = []

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Could not read frame.")
                break

            frames_buffer.append(frame)

            if len(frames_buffer) >= 8:
                input_data = {input_name: preprocess(frames_buffer)}
                outputs = model.process(graph_name, input_data)
                pred_idx = postprocess(outputs)

                # Bias is estimated on the newest frame of the batch.
                gray_frame = cv2.cvtColor(frames_buffer[-1], cv2.COLOR_BGR2GRAY)
                bias = calculate_bias(gray_frame)

                if pred_idx is not None and bias is not None:
                    if ser is None:
                        ser = serial.Serial(port, baudrate, timeout=1)

                    TS_class = _select_ts_class(pred_idx, bias)
                    TS_pos = int(abs(bias))

                    if ser.is_open:
                        # Both payload values must fit in 7 bits; for such
                        # values this is byte-for-byte what encoding
                        # 'C' + chr(class) + 'P' + chr(pos) as UTF-8 gives.
                        if TS_pos <= 127 and TS_class <= 127:
                            ser.write(
                                bytes((ord('C'), TS_class, ord('P'), TS_pos))
                            )
                    else:
                        print(f"无法打开串口 {port}")

                # Start the next batch with the newest frame, so that
                # consecutive batches overlap by one frame.
                frames_buffer = frames_buffer[-1:]

            cv2.imshow('Camera', frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    finally:
        cap.release()
        cv2.destroyAllWindows()
        if ser is not None and ser.is_open:
            ser.close()

if __name__ == '__main__':
    # Script entry point: run the live camera pipeline with the compiled
    # model expected in the current working directory.
    model_path = 'compilation.bmodel'
    infer_from_camera(model_path=model_path)
